syntax = "proto3";

package tensorflow.tpu;

// In the comments here, "layout" refers to the top-level
// TPUEmbeddingOutputLayout proto contained in the TPUEmbeddingConfiguration.

// The embedding output consists of a list of tensors, each specified by an
// EmbeddingOutputTensor proto within the TPUEmbeddingOutputLayout (the
// "output" field). The result of each table and feature lookup is then placed
// into some number of positions within some output tensor (identified by
// "tensor_index" within OutputLocation). The tree of table lookups, feature
// lookups, and output locations is specified by the
// "table(table_id).feature(feature_id).output_location" repeated fields within
// TPUEmbeddingOutputLayout.

// Describes where the result of every embedding lookup is placed in the
// output tensors: "table" lists, per table and per feature, the locations
// each feature's data goes into, and "output" lists the tensors that those
// locations refer to by index.
message TPUEmbeddingOutputLayout {
  // Location of one copy of the feature's data.
  message OutputLocation {
    // Which output tensor this copy of the feature will go into. Must be
    // between 0 and layout.output_size().
    // NOTE(review): since this indexes layout.output, the upper bound is
    // presumably exclusive -- confirm.
    int32 tensor_index = 1;

    // Offset in dimension 0 for this feature copy. Must be between 0 and
    // layout.output(tensor_index).two_d().dim0_size_per_sample().
    // NOTE(review): whether the upper bound is inclusive is not stated here
    // -- confirm against the validation code.
    int32 dim0_offset = 2;

    // Offset in dimension 1 for this feature copy. Must be between 0 and
    // layout.output(tensor_index).two_d().dim1_size() - table width.
    // Repeated or partially/fully overlapping values are allowed; results
    // that land in the same range will be summed (with the gradients
    // replicated in the backward pass).
    int32 dim1_offset = 3;
  }

  // Description of the output placement for one feature.
  message FeatureDescriptor {
    // All locations that receive this feature's data. Typically, only one
    // copy of each feature is used, but multiple are allowed and the same
    // data will be copied to all of them (with the gradients summed in the
    // backward pass).
    repeated OutputLocation output_location = 1;
  }

  // Description of the output placement for features of one table.
  message TableDescriptor {
    // Output locations for each feature loaded from this table.
    repeated FeatureDescriptor feature = 1;
  }
  // Output locations for each feature of each table. Addressed as
  // table(table_id).feature(feature_id).output_location.
  repeated TableDescriptor table = 1;

  // Size and layout information for 2-D tensors.
  //
  // The field numbers below intentionally differ from declaration order.
  // Numbers, not names or positions, identify fields on the wire, so they
  // must not be renumbered to "tidy up".
  message TwoDOutputTensor {
    // Multiplier for output dimension 0 size; used to match legacy format that
    // stacks features within a sample in dimension 0.
    int32 dim0_size_per_sample = 2;

    // The size (in dimension 1) of this output tensor.
    int32 dim1_size = 1;
  }

  // Data layout and shape computation information for a single output tensor.
  // Any unused locations in the tensor will be filled with zeros, and
  // corresponding gradients will be ignored.
  message EmbeddingOutputTensor {
    // Format of this tensor; two_d is the only format declared here.
    // NOTE(review): field numbers 1-3 are unused in this message. If they
    // belonged to removed fields they should be declared `reserved` so they
    // cannot be reused -- confirm against the file history.
    oneof output_format {
      TwoDOutputTensor two_d = 4;
    }
  }

  // Shape and layout information for each tensor. OutputLocation.tensor_index
  // refers to an entry of this list.
  repeated EmbeddingOutputTensor output = 2;
}
